from os import fork
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
from torch import nn
from torch.nn import functional as F

from models.gcnn import GraphConvolution
import ipdb
st = ipdb.set_trace  # debugging shortcut: call `st()` to drop into the ipdb debugger


class ConceptEBM(nn.Module):
    """Pairwise energy-based model for relations between two sets of boxes.

    Every (subject, object) pair is described by the offsets between their
    corners, scored by a small MLP, and the pair scores are summed into a
    single energy per batch element.
    """

    def __init__(self):
        """Build the pair-scoring MLP (8 corner offsets -> scalar)."""
        super().__init__()
        self.g_net = nn.Sequential(
            nn.Linear(8, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1)
        )

    @staticmethod
    def _to_corners(boxes):
        """Convert (cx, cy, w, h) boxes to (x_min, y_min, x_max, y_max)."""
        half = boxes[..., 2:] / 2
        return torch.cat((boxes[..., :2] - half, boxes[..., :2] + half), 2)

    def forward(self, rel_boxes, ref_boxes):
        """
        Compute the energy of a subject/object box configuration.

        Inputs:
            rel_boxes (tensor): subject boxes (B, N_rel, 4), (center, size)
            ref_boxes (tensor): object boxes (B, N_ref, 4), (center, size)

        Returns:
            tensor (B, 1): sum of the pairwise energies over all pairs
        """
        n_rel, n_ref = rel_boxes.size(1), ref_boxes.size(1)
        # (B, N_rel, N_ref, 4) grids of subject / object corners
        subj_grid = self._to_corners(rel_boxes).unsqueeze(2).expand(
            -1, -1, n_ref, -1
        )
        obj_grid = self._to_corners(ref_boxes).unsqueeze(1).expand(
            -1, n_rel, -1, -1
        )
        # Offsets to the object's own corners and to its opposite corners
        swapped = obj_grid[..., (2, 3, 0, 1)]
        pair_feats = torch.cat((subj_grid - obj_grid, subj_grid - swapped), 3)
        pair_energy = self.g_net(pair_feats)  # (B, N_rel, N_ref, 1)
        return pair_energy.sum(2).sum(1)


class ConceptEBMShapeOrder(nn.Module):
    """Concept EBM for arbitrary shapes whose boxes form an ordered ring."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        self.g_net = nn.Sequential(
            nn.Linear(6, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128)
        )
        layer = nn.TransformerEncoderLayer(128, 4, 128)
        self.t_net = nn.TransformerEncoder(layer, 1)
        self.f_net = nn.Linear(128, 1)

    @staticmethod
    def _ring_indices(num_boxes, max_boxes, step):
        """Return (B, max_boxes) indices of each slot's cyclic neighbour.

        For slot i < n (n = number of real boxes of that sample) the index is
        (i + step) mod n, i.e. the neighbour along the closed ring formed by
        the first n slots; padding slots (i >= n) point to themselves.
        """
        slots = torch.arange(max_boxes, device=num_boxes.device).unsqueeze(0)
        n = num_boxes.unsqueeze(1)
        # clamp avoids a modulo by zero for samples without real boxes
        ring = torch.remainder(slots + step, n.clamp(min=1))
        return torch.where(slots < n, ring, slots)

    def forward(self, boxes, centers, lengths, order=None):
        """
        Compute the energy of a shape formed by an ordered set of boxes.

        Inputs:
            boxes (tensor): subject boxes (B, N, 4), (center, size); boxes
                with (near-)zero size are padding. The ring is built over the
                first n slots, so padding is assumed to come last.
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B,)
            order: unused, kept for interface compatibility

        Returns:
            tensor (B, 1): sum of the per-box energies of the real boxes
        """
        # This EBM moves only the boxes
        centers = centers.detach()

        src_key_padding_mask = boxes[..., 2:].sum(-1) < 1e-8  # padding boxes
        num_boxes = (~src_key_padding_mask).long().sum(1)

        # Box centers relative to the shape center, normalized by its length
        feats = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)
        feats = feats / lengths[:, None, None]

        # Features of the next / previous box along the ring (vectorized, so
        # no Python loop over the batch and no host/device sync)
        max_boxes = boxes.size(1)
        fwd_idx = self._ring_indices(num_boxes, max_boxes, 1)
        bwd_idx = self._ring_indices(num_boxes, max_boxes, -1)
        forward_feats = torch.gather(
            feats, 1, fwd_idx.unsqueeze(-1).expand_as(feats)
        )
        backward_feats = torch.gather(
            feats, 1, bwd_idx.unsqueeze(-1).expand_as(feats)
        )
        feats = torch.cat([
            feats,
            feats - forward_feats,
            feats - backward_feats
        ], 2)  # (B, N, 6)

        # Compute energy
        feats = self.g_net(feats)
        feats = feats.transpose(0, 1)
        feats = self.t_net(feats, src_key_padding_mask=src_key_padding_mask)
        feats = feats.transpose(0, 1)
        return (
            self.f_net(feats)
            * (1 - src_key_padding_mask.float()).unsqueeze(-1)
        ).sum(1)


class ConceptEBMShape2(nn.Module):
    """Concept EBM for arbitrary shapes."""

    def __init__(self, feat_dim=128):
        """Initialize layer for given feature dimension and heads."""
        super().__init__()
        # Transformer to model context between edges
        self.fc_t = nn.Linear(4, feat_dim)
        layer = nn.TransformerEncoderLayer(feat_dim, 4, feat_dim)
        self.t_net = nn.TransformerEncoder(layer, 1)

        # Layer to score each pair wrt its importance to attention
        self.fc_att = nn.Linear(feat_dim, 1)

        # Projection layers for all nodes
        # NOTE(review): fc_n is not used in forward; kept for checkpoints
        self.fc_v = nn.Linear(2, feat_dim)
        self.fc_n = nn.Linear(2, feat_dim)

        # Post-attention layers
        self.linear1 = nn.Linear(feat_dim, feat_dim)
        self.dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(feat_dim, feat_dim)
        self.norm1 = nn.LayerNorm(feat_dim)
        self.norm2 = nn.LayerNorm(feat_dim)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.activation = nn.LeakyReLU()

        # Energy part
        layer = nn.TransformerEncoderLayer(feat_dim, 4, feat_dim)
        self.g_net = nn.TransformerEncoder(layer, 1)
        self.f_net = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.LeakyReLU(),
            nn.Linear(feat_dim, 1)
        )

    def forward(self, boxes, centers, lengths):
        """
        Compute the energy of a shape formed by an unordered set of boxes.

        Inputs:
            boxes (tensor): subject boxes (B, N, 4), (center, size); boxes
                with (near-)zero size are treated as padding
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B,)

        Returns:
            tensor (B, 1): sum of the per-node energies of the real boxes
        """
        att_mask = boxes[..., 2:].sum(-1) < 1e-8  # padding boxes, (B, N)
        batch, num = att_mask.shape

        # Box centers relative to the shape center, normalized by its length
        node_feats = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)
        node_feats = node_feats / lengths[:, None, None]

        # All ordered pairs (i, j): query is node i, key is node j
        query = node_feats.unsqueeze(2).repeat(1, 1, num, 1)
        key = node_feats.unsqueeze(1).repeat(1, num, 1, 1)
        # A pair is masked if either of its two boxes is padding
        mask = att_mask.unsqueeze(2) | att_mask.unsqueeze(1)  # (B, N, N)

        # Relative offsets concatenated with the query node's own features
        diffs = query - key  # (B, N, N, 2)
        features = torch.cat((diffs, query), -1)  # (B, N, N, 4)
        feats = features.reshape(-1, num, features.size(3))  # (B*N, N, 4)
        mask = mask.reshape(-1, num)  # (B*N, N)

        # Context between the pairs that share the same query node
        # NOTE(review): rows whose query box is padding are fully masked,
        # which may yield NaN attention outputs for those rows; their scores
        # are overwritten by masked_fill below — confirm gradients stay finite.
        feats = self.fc_t(feats).transpose(0, 1)
        feats = self.t_net(feats, src_key_padding_mask=mask)
        feats = feats.transpose(0, 1)  # (B*N, N, F')

        # Importance of each pair, softmax-normalized over the keys j
        p_attn = self.fc_att(feats).squeeze(-1)  # (B*N, N)
        p_attn = p_attn.masked_fill(mask, -10000).softmax(-1)
        p_attn = p_attn.reshape(batch, num, num)  # (B, N, N)

        # Attention-weighted aggregation (node and value projections are the
        # same fc_v output, so it is computed once), then a feed-forward block
        nodes = self.fc_v(node_feats)  # (B, N, F')
        value = nodes
        nodes = self.norm1(nodes + self.dropout1(torch.matmul(p_attn, value)))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(nodes))))
        nodes = self.norm2(nodes + self.dropout2(src2))

        # Score each node; padding nodes do not contribute
        nodes = nodes.transpose(0, 1)
        nodes = self.g_net(nodes, src_key_padding_mask=att_mask)
        nodes = nodes.transpose(0, 1)
        return (
            self.f_net(nodes)
            * (1 - att_mask.float()).unsqueeze(-1)
        ).sum(1)


class ConceptEBMShape(nn.Module):
    """Concept EBM for arbitrary shapes (pairwise attention over boxes)."""

    def __init__(self, feat_dim=128):
        """Initialize layer for given feature dimension and heads."""
        super().__init__()
        # Transformer over edges (not used by forward; kept for checkpoints)
        self.fc_t = nn.Linear(4, feat_dim)
        encoder_layer = nn.TransformerEncoderLayer(feat_dim, 4, feat_dim)
        self.t_net = nn.TransformerEncoder(encoder_layer, 1)

        # MLP scoring how much node i should attend to node j
        self.fc_att = nn.Sequential(
            nn.Linear(4, feat_dim),
            nn.LeakyReLU(),
            nn.Linear(feat_dim, 1)
        )

        # Node projections (fc_n is not used by forward)
        self.fc_v = nn.Linear(2, feat_dim)
        self.fc_n = nn.Linear(2, feat_dim)

        # Residual feed-forward block applied after attention
        self.linear1 = nn.Linear(feat_dim, feat_dim)
        self.dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(feat_dim, feat_dim)
        self.norm1 = nn.LayerNorm(feat_dim)
        self.norm2 = nn.LayerNorm(feat_dim)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.activation = nn.LeakyReLU()

        # Node context transformer and per-node energy head
        encoder_layer = nn.TransformerEncoderLayer(feat_dim, 4, feat_dim)
        self.g_net = nn.TransformerEncoder(encoder_layer, 1)
        self.f_net = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.LeakyReLU(),
            nn.Linear(feat_dim, 1)
        )

    def forward(self, boxes, centers, lengths):
        """
        Compute the energy of a shape formed by an unordered set of boxes.

        Inputs:
            boxes (tensor): subject boxes (B, N, 4), (center, size); boxes
                with (near-)zero size are treated as padding
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B,)

        Returns:
            tensor (B, 1): sum of the per-node energies of the real boxes
        """
        pad = boxes[..., 2:].sum(-1) < 1e-8  # (B, N)
        count = pad.size(1)

        # Normalized offsets of the box centers from the shape center
        rel = (boxes[..., :2] - centers.unsqueeze(1)) / lengths[:, None, None]

        # Pairwise inputs [rel_i - rel_j, rel_i] for every (i, j)
        rel_i = rel.unsqueeze(2).repeat(1, 1, count, 1)
        rel_j = rel.unsqueeze(1).repeat(1, count, 1, 1)
        pair_inputs = torch.cat((rel_i - rel_j, rel_i), -1)  # (B, N, N, 4)

        # Attention weights over keys j; padding keys are suppressed
        scores = self.fc_att(pair_inputs).squeeze(-1)  # (B, N, N)
        weights = scores.masked_fill(pad.unsqueeze(1), -10000).softmax(-1)

        # Residual attention + feed-forward block
        projected = self.fc_v(rel)  # (B, N, F')
        nodes = self.norm1(
            projected + self.dropout1(torch.matmul(weights, projected))
        )
        hidden = self.dropout(self.activation(self.linear1(nodes)))
        nodes = self.norm2(nodes + self.dropout2(self.linear2(hidden)))

        # Node context, then per-node energy summed over the real boxes
        seq = self.g_net(nodes.transpose(0, 1), src_key_padding_mask=pad)
        seq = seq.transpose(0, 1)
        keep = (1 - pad.float()).unsqueeze(-1)
        return (self.f_net(seq) * keep).sum(1)


class ConceptEBMPose3(nn.Module):
    """Concept EBM scoring entity orientations relative to a shape center."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        self.g_net = nn.Sequential(
            nn.Linear(4, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU()
        )
        # t_net is not used by forward (kept for checkpoints)
        encoder_layer = nn.TransformerEncoderLayer(128, 4, 128)
        self.t_net = nn.TransformerEncoder(encoder_layer, 1)
        self.f_net = nn.Linear(128, 1)

    def forward(self, points, angles, centers, lengths):
        """
        Compute the summed per-entity pose energy.

        Inputs:
            points (tensor): (B, N, 2), location of each of N shapes
            angles (tensor): (B, N); NOTE(review): the code uses
                cos/sin(pi * angles), so angles look normalized to [-1, 1]
                rather than radians in [-pi, pi] — confirm
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B, 1); unused

        Entities with zero location and zero angle are padding.

        Returns:
            tensor (B, 1): sum of the per-entity energies
        """
        # Gradients flow only through angles (and centers), not points
        points = points.detach()
        is_pad = (points.abs().sum(-1) < 1e-8) & (angles.abs() < 1e-8)

        # Heading and its perpendicular as unit vectors
        phase = np.pi * angles
        heading = torch.stack((torch.cos(phase), torch.sin(phase)), -1)
        normal = torch.stack(
            (torch.cos(phase - np.pi / 2), torch.sin(phase - np.pi / 2)), -1
        )

        # Unit vector from each entity towards the shape center
        to_center = centers.unsqueeze(1) - points  # (B, N, 2)
        norm = torch.sqrt((to_center ** 2).sum(-1) + 1e-14).unsqueeze(2)
        direction = to_center / norm

        # Element-wise alignment of heading / normal with that direction
        pose_feats = torch.cat(
            [heading * direction, normal * direction], 2
        )  # (B, N, 4)
        per_entity = self.f_net(self.g_net(pose_feats))  # (B, N, 1)
        return (per_entity * (1 - is_pad.float()).unsqueeze(-1)).sum(1)


class ConceptEBMPose(nn.Module):
    """Concept EBM for arbitrary shapes with poses."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        self.g_net = nn.Sequential(
            nn.Linear(4, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU()
        )
        # g1, g2 and t_net are not used by forward (kept for checkpoints)
        self.g1 = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU()
        )
        self.g2 = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU()
        )
        layer = nn.TransformerEncoderLayer(128, 4, 128)
        self.t_net = nn.TransformerEncoder(layer, 1)
        self.f_net = nn.Linear(128, 1)

    def forward(self, points, tri_centers, tri_lengths, centers, lengths):
        """
        Compute the mean per-entity energy of a pose configuration.

        Inputs:
            points (tensor): (B, N, 2), moving point of each of N entities
            tri_centers (tensor): (B, N, 2), center of each entity
            tri_lengths (tensor): (B, N), length of each entity's main axis
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B, 1); unused

        Entities whose point and length are both (near-)zero are padding.

        Returns:
            tensor (B, 1): energy averaged over the non-padding entities
        """
        pad = (
            (points.abs().sum(-1) < 1e-8)
            & (tri_lengths.abs() < 1e-8)
        )  # (B, N)
        num_boxes = (~pad).float().sum(1)
        ones = torch.ones_like(tri_lengths)

        # Vector from entity center to point, in units of the entity length.
        # Padding denominators are replaced by 1 so they cannot create inf/NaN
        # values (or NaN gradients): NaN * 0 is still NaN, so masking the
        # output afterwards would not be enough.
        vec1 = (
            (points - tri_centers)
            / torch.where(pad, ones, tri_lengths + 1e-8).unsqueeze(-1)
        )

        # Vector from point to shape center, scaled by the entity's distance
        # to the shape center minus half its length
        sq_radii = ((tri_centers - centers.unsqueeze(1)) ** 2).sum(-1)
        pseudo_radii = torch.sqrt(torch.where(pad, ones, sq_radii))
        denom = torch.where(pad, ones, pseudo_radii - 0.5 * tri_lengths)
        vec2 = (centers.unsqueeze(1) - points) / denom.unsqueeze(-1)  # (B, N, 2)

        feats = torch.cat([vec1, vec2], 2)  # (B, N, 4)

        # Per-entity energy; padding entries are set to exactly zero
        energy = self.f_net(self.g_net(feats))  # (B, N, 1)
        energy = energy.masked_fill(pad.unsqueeze(-1), 0.0)
        # clamp guards against division by zero for all-padding samples
        return energy.sum(1) / num_boxes.clamp(min=1).unsqueeze(-1)


class ConceptEBMTable(nn.Module):
    """Concept EBM for tables (one box per item word)."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        self.g_net = nn.Sequential(
            nn.Linear(4 + 2*32, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            # Previously nn.LeakyReLU(128): the positional argument of
            # LeakyReLU is negative_slope, not a width, so that scaled
            # negative activations by 128.
            nn.LeakyReLU(),
            nn.Linear(128, 1)
        )
        self.words = ['plate', 'napkin', 'fork', 'knife', 'bowl']
        self.embeddings = nn.Embedding(len(self.words), 32)

    def forward(self, boxes, edges):
        """
        Compute the summed pairwise energy of a table arrangement.

        Inputs:
            boxes (tensor): boxes (B, N, 4), (center, size); box i is paired
                with the embedding of self.words[i], so N must equal
                len(self.words)
            edges (tensor): two points of the edges (B, 2, 2), (x, y); unused

        Returns:
            tensor (B,): sum of the pair energies over all i != j
        """
        boxes = torch.cat((
            boxes[..., :2] - boxes[..., 2:] / 2,
            boxes[..., :2] + boxes[..., 2:] / 2
        ), 2)
        # diffs[b, i, j] = corners_i - corners_j
        diffs = (
            boxes.unsqueeze(2).repeat(1, 1, boxes.size(1), 1)
            - boxes.unsqueeze(1).repeat(1, boxes.size(1), 1, 1)
        )
        inds = torch.arange(len(self.words), device=diffs.device)
        embs = self.embeddings(inds.unsqueeze(0).repeat(len(diffs), 1))
        # Pair embedding [emb_i, emb_j]
        embs = torch.cat((
            embs.unsqueeze(2).repeat(1, 1, embs.size(1), 1),
            embs.unsqueeze(1).repeat(1, embs.size(1), 1, 1)
        ), -1)
        feats = torch.cat((diffs, embs), -1)
        # Compute energy
        feats = self.g_net(feats).squeeze(-1)  # (B, N, N)
        # Exclude self-pairs; build the mask directly on the features' device
        # (subtracting a CPU eye from a device tensor fails on GPU)
        attention = 1 - torch.eye(
            feats.size(1), device=feats.device, dtype=feats.dtype
        ).unsqueeze(0)
        return (feats * attention).sum(1).sum(1)


class ConceptEBMTablePose(nn.Module):
    """Concept EBM for table item poses."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        # Previously nn.LeakyReLU(128): the positional argument of LeakyReLU
        # is negative_slope, not a width, so that scaled negatives by 128.
        self.g_net = nn.Sequential(
            nn.Linear(4, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1)
        )
        self.words = ['plate', 'napkin', 'fork', 'knife', 'bowl']
        # Not used by forward; kept for checkpoint compatibility
        self.embeddings = nn.Embedding(len(self.words), 32)

    def forward(self, points, centers, sizes, edges):
        """
        Compute the mean per-entity pose energy.

        Inputs:
            points (tensor): (B, N, 2), moving point of each of N entities
            centers (tensor): (B, N, 2), center of each entity
            sizes (tensor): (B, N, K); only sizes[..., 0] (length of each
                entity's main axis) is used
            edges (tensor): two points of the edges (B, 2, 2); unused

        Returns:
            tensor (B,): mean of the per-entity energies
        """
        sizes = sizes[..., 0]
        # Constant (0, 1) unit vector per entity, created on the inputs'
        # device and dtype so the concatenation below also works on GPU
        u_vecs = points.new_zeros(points.size(0), points.size(1), 2)
        u_vecs[..., 1] = 1

        # Vector from entity center to point, in units of the entity size
        vec1 = (points - centers) / (sizes + 1e-8).unsqueeze(-1)

        feats = torch.cat((vec1, u_vecs), -1)  # (B, N, 4)
        # Compute energy
        feats = self.g_net(feats).squeeze(-1)  # (B, N)
        return feats.mean(-1)


class ConceptEBMShapeUnOrder_(nn.Module):
    """Concept EBM for arbitrary shapes (node + aggregated edge features)."""

    def __init__(self, nlayers=2):
        """Initialize layers."""
        super().__init__()
        # Node encoder: [offset from center (2), lengths (2)] -> 128
        self.g_net = nn.Sequential(
            nn.Linear(4 + 0, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128)
        )
        # Edge encoder: [pairwise offset (2), lengths (2)] -> 128
        self.e_net = nn.Sequential(
            nn.Linear(4 + 0, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128)
        )
        encoder_layer = nn.TransformerEncoderLayer(256, 4, 256)
        self.t_net = nn.TransformerEncoder(encoder_layer, nlayers)
        self.f_net = nn.Linear(256, 1)

    def forward(self, boxes, centers, lengths, _n):
        """
        Compute the mean per-box energy of a shape.

        Inputs:
            boxes (tensor): boxes (B, N, 4), (center, size); boxes with
                (near-)zero size are treated as padding
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): shape descriptor concatenated to every input;
                must be 2-D (B, 2) since the encoders take 2 + 2 features
            _n: unused

        Returns:
            tensor (B,): energy averaged over the real boxes
        """
        pad = boxes[..., 2:].sum(-1) < 1e-8  # (B, N)
        keep = 1 - pad.float()
        count = keep.sum(1).long()
        size = boxes.size(1)

        # Offsets of the box centers from the shape center (not normalized)
        rel = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)
        # Offsets between every pair of boxes; the second operand is detached
        pair = rel.unsqueeze(2) - rel.unsqueeze(1).detach()  # (B, N, N, 2)

        # Shape-level context appended to every node / edge input
        ctx = lengths.unsqueeze(1).repeat(1, size, 1)  # (B, N, 2)
        node_emb = self.g_net(torch.cat([rel, ctx], -1))  # (B, N, 128)
        edge_ctx = ctx.unsqueeze(-2).repeat(1, 1, size, 1)  # (B, N, N, 2)
        edge_emb = self.e_net(torch.cat([pair, edge_ctx], -1))  # (B, N, N, 128)

        # Drop edges pointing to padding boxes, then self-edges
        edge_emb = edge_emb * (1 - pad.float().unsqueeze(1).unsqueeze(-1))
        no_self = (1 - torch.eye(size)).unsqueeze(0).unsqueeze(-1)
        edge_emb = edge_emb * no_self.to(edge_emb.device)
        node_emb = torch.cat([node_emb, edge_emb.sum(2)], -1)  # (B, N, 256)

        # Context over boxes, then masked mean of the per-box energies
        ctxed = self.t_net(node_emb.transpose(0, 1), src_key_padding_mask=pad)
        ctxed = ctxed.transpose(0, 1)
        scores = self.f_net(ctxed).squeeze(-1) * keep
        return scores.sum(1) / count


class ConceptEBMShapeUnOrder2(nn.Module):
    """Concept EBM for arbitrary shapes (distance-weighted edge features)."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        self.g_net = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128)
        )
        self.e_net = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128)
        )
        layer = nn.TransformerEncoderLayer(256, 4, 256)
        self.t_net = nn.TransformerEncoder(layer, 1)
        self.f_net = nn.Linear(256, 1)

    def forward(self, boxes, centers, lengths):
        """
        Compute the mean per-box energy of a shape.

        Inputs:
            boxes (tensor): subject boxes (B, N, 4), (center, size); boxes
                with (near-)zero size are treated as padding
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B,)

        Returns:
            tensor (B, 1): energy averaged over the real boxes
        """
        src_key_padding_mask = boxes[..., 2:].sum(-1) < 1e-8  # padding boxes
        num_boxes = (1 - src_key_padding_mask.float()).sum(1).long()
        size = boxes.size(1)

        # Difference with center and normalize with length
        cfeats = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)
        feats = cfeats / lengths[:, None, None]

        # Soft adjacency: inverse squared distance (no gradient), softmaxed
        # over neighbours, excluding self-loops and padding boxes
        dists = cfeats.unsqueeze(1) - cfeats.unsqueeze(2)
        adjacency = 1 / ((dists.detach() ** 2).sum(-1) + 1e-8)  # (B, N, N)
        adjacency[:, range(size), range(size)] = -9e15
        adjacency[src_key_padding_mask.unsqueeze(1).repeat(1, size, 1)] = -9e15
        adjacency = (0.1 * adjacency).softmax(-1)

        # Pairwise edge features; the second operand is detached
        edges = feats.unsqueeze(2) - feats.unsqueeze(1).detach()  # (B, N, N, 2)

        # Node features plus adjacency-weighted sum of edge embeddings
        feats = self.g_net(feats)  # (B, N, 128)
        edges = self.e_net(edges)  # (B, N, N, 128)
        edges = torch.matmul(adjacency.unsqueeze(2), edges).squeeze(2)  # (B, N, 128)
        feats = torch.cat([feats, edges], -1)  # (B, N, 256)

        # Compute energy
        feats = feats.transpose(0, 1)
        feats = self.t_net(feats, src_key_padding_mask=src_key_padding_mask)
        feats = feats.transpose(0, 1)
        energy = (
            self.f_net(feats)
            * (1 - src_key_padding_mask.float()).unsqueeze(-1)
        ).sum(1)  # (B, 1)
        # Divide by a (B, 1) count: dividing (B, 1) by (B,) would broadcast
        # to (B, B). clamp guards against all-padding samples.
        return energy / num_boxes.clamp(min=1).unsqueeze(-1)


class ConceptEBMShapeOrder2(nn.Module):
    """Concept EBM for arbitrary shapes whose boxes form an ordered ring."""

    def __init__(self):
        """Initialize layers."""
        super().__init__()
        self.g_net = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64)
        )
        self.e_net = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64)
        )
        # et_net is not used by forward (kept for checkpoints)
        layer = nn.TransformerEncoderLayer(64, 4, 64)
        self.et_net = nn.TransformerEncoder(layer, 1)
        layer = nn.TransformerEncoderLayer(128, 4, 128)
        self.t_net = nn.TransformerEncoder(layer, 1)
        self.f_net = nn.Linear(128, 1)

    @staticmethod
    def _next_indices(num_boxes, max_boxes):
        """Return (B, max_boxes) indices of each slot's next box on the ring.

        For slot i < n (n = number of real boxes of that sample) the index is
        (i + 1) mod n; padding slots (i >= n) point to themselves.
        """
        slots = torch.arange(max_boxes, device=num_boxes.device).unsqueeze(0)
        n = num_boxes.unsqueeze(1)
        # clamp avoids a modulo by zero for samples without real boxes
        ring = torch.remainder(slots + 1, n.clamp(min=1))
        return torch.where(slots < n, ring, slots)

    def forward(self, boxes, centers, lengths):
        """
        Compute the energy of a shape formed by an ordered set of boxes.

        Inputs:
            boxes (tensor): subject boxes (B, N, 4), (center, size); boxes
                with (near-)zero size are padding. The ring is built over the
                first n slots, so padding is assumed to come last.
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B,)

        Returns:
            tensor (B, 1): sum of the per-box energies of the real boxes
        """
        # This EBM moves only the boxes
        centers = centers.detach()

        src_key_padding_mask = boxes[..., 2:].sum(-1) < 1e-8  # padding boxes
        num_boxes = (~src_key_padding_mask).long().sum(1)

        # Box centers relative to the shape center, normalized by its length
        feats = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)
        feats = feats / lengths[:, None, None]

        # Features of the next box along the ring (vectorized gather)
        next_idx = self._next_indices(num_boxes, boxes.size(1))
        forward_feats = torch.gather(
            feats, 1, next_idx.unsqueeze(-1).expand_as(feats)
        )

        # Node embedding (64) + embedding of the offset to the next box (64)
        feats = torch.cat([
            self.g_net(feats),
            self.e_net(feats - forward_feats)
        ], 2)  # (B, N, 128)

        # Compute energy
        feats = feats.transpose(0, 1)
        feats = self.t_net(feats, src_key_padding_mask=src_key_padding_mask)
        feats = feats.transpose(0, 1)
        return (
            self.f_net(feats)
            * (1 - src_key_padding_mask.float()).unsqueeze(-1)
        ).sum(1)


class ConceptEBMShapeUnOrder(nn.Module):
    """Order-agnostic concept EBM for arbitrary shapes.

    Each box becomes one token built from two offsets, both divided by the
    shape length: the box center relative to the shape center, and the box
    center relative to a reference point ``first``. A Transformer encoder
    mixes the tokens and a linear head scores each one. The energy is the mean
    score over the non-padding boxes.
    """

    def __init__(self, nlayers=1):
        """
        Initialize layers.

        Inputs:
            nlayers (int): number of Transformer encoder layers
        """
        super().__init__()
        # per-box input: 2 (offset from center) + 2 (offset from `first`)
        self.g_net = nn.Sequential(
            nn.Linear(4, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 256),
        )
        # NOTE(review): not used in forward(); presumably kept so saved
        # state_dicts still match this module's parameter names.
        self.e_net = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=256, nhead=4, dim_feedforward=256
        )
        self.t_net = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
        self.f_net = nn.Linear(256, 1)

    def forward(self, boxes, centers, lengths, first):
        """
        Forward pass.

        Inputs:
            boxes (tensor): boxes (B, N, 4), (center, size); boxes whose two
                size values sum below 1e-8 are treated as padding
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B,)
            first (tensor): reference point, broadcastable against `centers`
                (presumably (B, 2)), (x, y)

        Returns:
            (B,) energy: mean per-box score over the non-padding boxes
        """
        is_pad = boxes[..., 2:].sum(-1) < 1e-8  # (B, N)
        keep = 1 - is_pad.float()
        n_valid = keep.sum(1).long()

        # Both offsets are expressed in units of the shape length
        rel_center = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)
        rel_center = rel_center / lengths[:, None, None]
        anchor = ((first - centers) / lengths[:, None]).unsqueeze(1)
        rel_first = rel_center - anchor  # (B, N, 2)

        tokens = self.g_net(torch.cat([rel_center, rel_first], -1))

        # nn.TransformerEncoder here is sequence-first: (N, B, C)
        tokens = self.t_net(
            tokens.transpose(0, 1), src_key_padding_mask=is_pad
        ).transpose(0, 1)
        scores = self.f_net(tokens).squeeze(-1)  # (B, N)
        return (scores * keep).sum(1) / n_valid


class ConceptEBMShapeOrderNext(nn.Module):
    """Concept EBM for ordered shapes, built on a pretrained "is next" scorer.

    The model works in two stages:

    1. Ordering: ``is_next_net`` (initialized from a checkpoint) scores every
       ordered pair of boxes ``(i, j)``. A linear assignment over the
       row-softmaxed scores picks one "next" box per box. This gives a
       forward index permutation and its inverse, the backward indices.
    2. Placement: ``g_net`` -> Transformer ``t_net`` -> ``f_net`` scores each
       box from its normalized offset to the shape center and its offsets to
       its forward and backward boxes.
    """

    def __init__(self, checkpoint_path='circle_next.pt'):
        """
        Initialize layers and load the pretrained ``is_next_net`` weights.

        Inputs:
            checkpoint_path (str): file read with ``torch.load``. Entries of
                its ``"model_state_dict"`` whose keys start with
                ``is_next_net`` initialize ``self.is_next_net``. The default
                is the previously hard-coded path, relative to the working
                directory.
        """
        super().__init__()
        # pair input: 2 (box i offset) + 2 (box j offset) + 1 (norm length)
        self.is_next_net = nn.Sequential(
            nn.Linear(5, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1)
        )
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = {
            k.replace('is_next_net.', ''): v
            for k, v in checkpoint["model_state_dict"].items()
            if k.startswith('is_next_net')
        }
        self.is_next_net.load_state_dict(state_dict)
        # per-box input: offset, offset - next, offset - previous (2 each)
        self.g_net = nn.Sequential(
            nn.Linear(6, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128)
        )
        layer = nn.TransformerEncoderLayer(128, 4, 128)
        self.t_net = nn.TransformerEncoder(layer, 1)
        self.f_net = nn.Linear(128, 1)

    @staticmethod
    def _gather_boxes(feats, inds):
        """Return ``out[b, i] = feats[b, inds[b, i]]``, shape (B, N, D)."""
        return torch.gather(
            feats, 1, inds.unsqueeze(-1).expand(-1, -1, feats.size(-1))
        )

    def _order_ebm(self, feats, norm_length, src_key_padding_mask):
        """
        Score every ordered pair of boxes ``(i, j)`` as "j follows i".

        Inputs:
            feats (tensor): box offsets from the shape center (B, N, 2)
            norm_length (tensor): per-sample scalar (B,), appended to each pair
            src_key_padding_mask (tensor): True for padding boxes (B, N)

        Returns:
            (B, N, N) scores. `_find_next` prefers higher scores. Columns of
            padding boxes and the diagonal are set to -1e10 so they are
            effectively never chosen.
        """
        n_boxes = feats.size(1)
        pair_in = torch.cat([
            feats.unsqueeze(2).expand(-1, -1, n_boxes, -1),  # box i
            feats.unsqueeze(1).expand(-1, n_boxes, -1, -1),  # box j
            norm_length[:, None, None, None].expand(-1, n_boxes, n_boxes, 1),
        ], -1)  # (B, N, N, 5)
        is_next_energies = self.is_next_net(pair_in).squeeze(-1)  # (B, N, N)

        # Masks are built on the energies' device and broadcast over (B, N, N)
        pad_cols = src_key_padding_mask.unsqueeze(1)  # (B, 1, N)
        self_pairs = torch.eye(
            n_boxes, dtype=torch.bool, device=is_next_energies.device
        )  # (N, N)
        is_next_energies = is_next_energies.masked_fill(pad_cols, -1e10)
        return is_next_energies.masked_fill(self_pairs, -1e10)

    def _place_ebm(self, feats, src_key_padding_mask):
        """
        Per-box placement scores.

        Inputs:
            feats (tensor): per-box features (B, N, 6)
            src_key_padding_mask (tensor): True for padding boxes (B, N)

        Returns:
            (B, N) scores (padding positions are not zeroed here)
        """
        feats = self.g_net(feats)
        # nn.TransformerEncoder here is sequence-first: (N, B, C)
        feats = feats.transpose(0, 1)
        feats = self.t_net(feats, src_key_padding_mask=src_key_padding_mask)
        feats = feats.transpose(0, 1)
        return self.f_net(feats).squeeze(-1)  # (B, N)

    def _find_next(self, is_next_energies):
        """
        Turn pairwise scores into forward/backward index permutations.

        Inputs:
            is_next_energies (tensor): (B, N, N) scores from `_order_ebm`

        Returns:
            f_inds (tensor): (B, N) long; f_inds[b, i] is the box assigned
                after box i. This is the column Hungarian matching gives to
                row i when maximizing the row-softmaxed scores.
            b_inds (tensor): (B, N) long, the inverse permutation, so that
                b_inds[b, f_inds[b, i]] == i
        """
        device = is_next_energies.device
        # linear_sum_assignment minimizes cost, so negate the probabilities
        costs = -is_next_energies.softmax(-1)
        f_inds = torch.stack([
            torch.from_numpy(linear_sum_assignment(b_cost.numpy())[1])
            for b_cost in costs.detach().cpu()
        ]).to(device)  # (B, N)
        positions = torch.arange(f_inds.size(1), device=device)
        positions = positions.repeat(f_inds.size(0), 1)
        b_inds = torch.zeros_like(f_inds).scatter_(1, f_inds, positions)
        return f_inds, b_inds

    def forward(self, boxes, centers, lengths, order=None):
        """
        Infer a box ordering, or score box placements for a given ordering.

        Inputs:
            boxes (tensor): boxes (B, N, 4), (center, size); boxes whose two
                size values sum below 1e-8 are treated as padding
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape (B,)
            order (tuple or None): (forward_inds, backward_inds), each (B, N)
                long, e.g. as returned by a previous call with order=None

        Returns:
            If ``order`` is None: ``(is_next_energies, forward_inds,
            backward_inds)`` with shapes (B, N, N), (B, N), (B, N).
            Otherwise: (B,) placement energy, the mean per-box score over
            the non-padding boxes.
        """
        # This EBM moves only the boxes
        centers = centers.detach()
        lengths = lengths.detach()

        src_key_padding_mask = boxes[..., 2:].sum(-1) < 1e-8  # padding boxes
        valid = 1 - src_key_padding_mask.float()  # (B, N)
        num_boxes = valid.sum(1).long()  # (B,)

        # Difference with center
        feats = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)

        # Order EBM: score pairs, then pick next/previous boxes
        if order is None:
            is_next_energies = self._order_ebm(
                feats,
                lengths / num_boxes,
                src_key_padding_mask
            )  # (B, N, N)
            forward_inds, backward_inds = self._find_next(is_next_energies)
            return is_next_energies, forward_inds, backward_inds
        forward_inds, backward_inds = order

        # Consecutive pairwise differences, in units of the shape length
        feats = feats / lengths[:, None, None]
        forward_feats = self._gather_boxes(feats, forward_inds)
        backward_feats = self._gather_boxes(feats, backward_inds)
        feats = torch.cat([
            feats,
            feats - forward_feats,
            feats - backward_feats,
        ], 2)  # (B, N, 6)

        # Placement energy, averaged over the non-padding boxes
        place_energy = self._place_ebm(feats, src_key_padding_mask)
        return (place_energy * valid).sum(1) / num_boxes


class ConceptEBMShapeAttention(nn.Module):
    """Graph-style concept EBM for shapes with a fixed number of boxes.

    Every box is a node tagged with a learned embedding of its slot index.
    The energy is built from two parts:

    * for each node, the sum of the edge terms to every other node, averaged
      over nodes;
    * the node terms, averaged over nodes.
    """

    def __init__(self):
        """Initialize the node, edge and slot-embedding MLPs."""
        super().__init__()
        # input: offset from center (2) + slot embedding (6) + lengths (2)
        self.node_net = nn.Sequential(
            nn.Linear(10, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        )
        # input: pairwise offset (2) + both slot embeddings (12) + lengths (2)
        self.edge_net = nn.Sequential(
            nn.Linear(16, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        )
        # Fed one-hot slot indices, so the model expects exactly 6 boxes
        self.pos_net = nn.Sequential(
            nn.Linear(6, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 6),
        )

    def forward(self, boxes, centers, lengths, _n):
        """
        Forward pass.

        Inputs:
            boxes (tensor): boxes (B, 6, 4), (center, size); only the centers
                are used, and padding boxes are not masked out
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape, (B, 2) as required by
                the input widths of node_net / edge_net
            _n: unused

        Returns:
            (B,) energy
        """
        batch, n_boxes = boxes.size(0), boxes.size(1)
        rel = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)

        # Learned embedding of each box's slot index
        one_hot = torch.eye(n_boxes)[None].repeat(batch, 1, 1)
        slot = self.pos_net(one_hot.to(rel.device))  # (B, N, 6)

        len_per_node = lengths.unsqueeze(1).repeat(1, n_boxes, 1)  # (B, N, 2)
        node_energy = self.node_net(
            torch.cat([rel, slot, len_per_node], -1)
        ).squeeze(-1)  # (B, N)

        pair_diff = rel.unsqueeze(2) - rel.unsqueeze(1)  # (B, N, N, 2)
        edge_in = torch.cat([
            pair_diff,
            slot.unsqueeze(2).repeat(1, 1, n_boxes, 1),
            slot.unsqueeze(1).repeat(1, n_boxes, 1, 1),
            len_per_node.unsqueeze(-2).repeat(1, 1, n_boxes, 1),
        ], -1)
        edge_energy = self.edge_net(edge_in).squeeze(-1)  # (B, N, N)

        # Drop self-pairs (i, i) before summing the edge terms
        off_diagonal = (
            torch.ones_like(edge_energy)
            - torch.eye(n_boxes, device=rel.device).unsqueeze(0)
        )
        pair_term = (edge_energy * off_diagonal).sum(2).mean(1)
        return pair_term + node_energy.mean(1)


class ConceptEBMShapeMove(nn.Module):
    """Concept EBM that scores the first two boxes of a shape.

    The energy is the sum of two terms:

    * a node term for box 0, from its offset to the shape center and
      `lengths`;
    * an edge term from the offsets of boxes 0 and 1 to their nearest other
      non-padding box.
    """

    def __init__(self, nlayers=2):
        """
        Initialize the node and edge MLPs.

        Inputs:
            nlayers (int): not used by this model
        """
        super().__init__()
        width = 128
        # input: box-0 offset from center (2) + lengths (2)
        self.node_net = nn.Sequential(
            nn.Linear(4, width), nn.LeakyReLU(),
            nn.Linear(width, width), nn.LeakyReLU(),
            nn.Linear(width, width), nn.LeakyReLU(),
            nn.Linear(width, 1),
        )
        # input: neighbour offsets of boxes 0 and 1 (2 + 2)
        self.edge_net = nn.Sequential(
            nn.Linear(4, width), nn.LeakyReLU(),
            nn.Linear(width, width), nn.LeakyReLU(),
            nn.Linear(width, width), nn.LeakyReLU(),
            nn.Linear(width, 1),
        )

    def forward(self, boxes, centers, lengths, _n):
        """
        Forward pass.

        Inputs:
            boxes (tensor): boxes (B, N, 4), (center, size), N >= 2; boxes
                whose two size values sum below 1e-8 are treated as padding
            centers (tensor): center of the shape (B, 2), (x, y)
            lengths (tensor): "length" of the shape, (B, 2) as required by
                node_net's input width
            _n: unused

        Returns:
            (B,) energy
        """
        is_pad = boxes[..., 2:].sum(-1) < 1e-8  # (B, N)
        offsets = boxes[..., :2] - centers.unsqueeze(1)  # (B, N, 2)
        n_boxes = offsets.size(1)

        node_energy = self.node_net(
            torch.cat([offsets[:, 0], lengths], -1)
        ).squeeze(-1)

        # Neighbour choice uses detached values, so no gradient flows through it
        frozen = offsets.detach()
        sq_dist = ((frozen.unsqueeze(2) - frozen.unsqueeze(1)) ** 2).sum(-1)
        closeness = 1 / (sq_dist + 1e-8)  # (B, N, N), larger = closer
        closeness[:, range(n_boxes), range(n_boxes)] = -9e15  # never self
        closeness[is_pad.unsqueeze(1).repeat(1, n_boxes, 1)] = -9e15  # never padding
        nearest = torch.topk(closeness, 2, -1).indices[..., 0]  # (B, N)

        neighbour = torch.gather(
            offsets, 1, nearest.unsqueeze(-1).expand(-1, -1, 2)
        )  # (B, N, 2)
        # rows 0 and 1 flattened to [dx0, dy0, dx1, dy1]
        edge_in = (offsets - neighbour)[:, :2].flatten(1)  # (B, 4)
        edge_energy = self.edge_net(edge_in).squeeze(-1)

        return node_energy + edge_energy
